[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059


Open — fhl2000 wants to merge 45 commits into main from the full_cudagraph_FA2_FlashInfer branch
Conversation

@fhl2000 (Contributor) commented Jun 25, 2025

Purpose

1. This PR introduces a new implementation of full CUDA graph capture and adds support for FA2 and FlashInfer.

Previous limitations

The original design in PR #16072 sets compilation_config.splitting_ops to an empty list and captures the full cudagraph inside the flattened fx graph, which supports FA3 only. The later PR #18581 adds full cudagraph support for FlashMLA, but it only captures the pure decode stage and bypasses the mixed prefill-decode stages, i.e., it runs the eager code of the compiled flattened fx graph in that stage. However, from the profiling results (see below), I found that this flattened graph has performance issues when called eagerly: it is about 2x slower on the CPU side than running the compiled piecewise fx graph (possibly a Python-level issue). This can lead to performance degradation when the prefill stage has a small batch size.

Also, attention backends like FA2, FlashInfer, and FlashMLA have two distinct attention routines, one for mixed prefill-decode batches and one for pure decode batches, which makes it difficult to contain both in a unified graph while keeping only one set of captured cudagraphs.

Solution of this PR.

The new approach is to keep the piecewise compiled fx graph structure overall, but capture the full cudagraph outside the fx graph via a wrapper. With this in hand, we can dispatch to two sets of cudagraphs. For the pure decode stage, we directly use full cudagraphs, since this stage is compatible with most attention backends. For mixed prefill-decode stages, we can either fall back to piecewise cudagraphs for incompatible routines in backends like FlashMLA and FlashInfer, or use another set of full cudagraphs for compatible backends (e.g., varlen support in FA2).

Note that keeping the piecewise compiled fx graph is at least better than a full but flattened one from the viewpoint of reducing CPU overhead, even if we do not capture the mixed prefill-decode stage. It also keeps the flexibility to switch between full cudagraphs and piecewise cudagraphs for future extensions, for example a seamless fallback to piecewise cudagraphs when cascade attention is needed.
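The following is a minimal sketch of this idea, with hypothetical names (FullCudagraphWrapper, use_full_cudagraph, etc.), not the PR's actual classes: a wrapper around the piecewise-compiled runnable that keeps two sets of full cudagraphs keyed by padded batch size and falls back to the piecewise path when needed.

```python
import torch


class FullCudagraphWrapper:
    """Minimal sketch (hypothetical names): wrap the piecewise-compiled runnable
    and keep two sets of full cudagraphs, keyed by the padded batch size."""

    def __init__(self, piecewise_runnable):
        # The piecewise runnable still contains its own piecewise-cudagraph logic;
        # here it is also what we record when capturing a full graph.
        self.piecewise_runnable = piecewise_runnable
        self.decode_graphs = {}  # batch_size -> (graph, static outputs), pure-decode routine
        self.mixed_graphs = {}   # batch_size -> (graph, static outputs), prefill-decode routine

    def __call__(self, *args, batch_size, is_pure_decode, use_full_cudagraph, **kwargs):
        if not use_full_cudagraph:
            # Incompatible routine (e.g. cascade attention): fall back to the
            # piecewise graph, which handles its own piecewise cudagraphs.
            return self.piecewise_runnable(*args, **kwargs)

        graphs = self.decode_graphs if is_pure_decode else self.mixed_graphs
        if batch_size not in graphs:
            # Capture once per (routine, batch size); inputs must live in
            # persistent buffers that are updated in place before each replay.
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = self.piecewise_runnable(*args, **kwargs)
            graphs[batch_size] = (graph, outputs)
            return outputs
        graph, outputs = graphs[batch_size]
        graph.replay()
        return outputs
```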

The limitation is the increased startup time and the additional GPU memory required for the extra cudagraph capture. We could optimize this by shrinking the list of batch sizes captured for the prefill-decode stage.

Profile of the compiled flattened fx graph in eager execution (mixed prefill-decode stage):

It takes roughly 56 ms to fully launch the model, with an additional 5 ms of latency spent on safety checks before launching the first kernel. It seems Python is slow at executing a flattened, large module without submodules.
[profiler trace screenshot]

Note: the only way to use the flattened fx graph in this PR is to hardcode splitting_ops = [] in set_splitting_ops_for_v1 (around line 4200 in vllm/config.py).

Profile of the compiled piecewise fx graph in eager execution (mixed prefill-decode stage):

It takes about 28 ms to fully launch the model, and the latency above almost disappears; in fact, it is hidden inside each submodule.
[profiler trace screenshot]

The patterns above were verified on two different machines (the GPU difference can be ignored here, as this is CPU-related only), tested on Qwen2.5-7B-Instruct-GPTQ-Int4 while profiling benchmark_serving (ShareGPT, unlimited request rate).

So, if a prefill batch size is a bit larger than the maximum capture size (say 512) but not too big, the lower bound of the model forward time may be set by the CPU side: around 56 ms when running the flattened graph, versus 28 ms for the piecewise one.

Details for supporting FA2:

The previous code did not distinguish the two routines inside FA2. FA2 launches a standard varlen fwd kernel on mixed prefill-decode batches, and a different routine for pure decode batches that includes an optimization for GQA/MQA and potential flash-decode kernels (split_kv > 1). By setting max_query_len = 1 or > 1 during the CUDA graph capture phase, we can activate the desired attention routine so that it is captured correctly. (Strictly speaking, the prefill-decode kernel is of course compatible with pure decode, just not fully optimized for it. The actual reason PR #16072 did not support FA2 is a bug: seq_lens was a zero tensor in dummy_run in the early code, which skipped launching any attention kernel during capture and produced zero-tensor outputs.)

  • FA2 runs both mixed prefill-decode and pure decode batches with full cudagraphs, but on two separate sets of cudagraphs (a capture-time sketch follows below).
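As a capture-time illustration (hypothetical helper and field names; only max_query_len and the non-zero seq_lens requirement come from the description above):

```python
def build_dummy_attn_metadata(batch_size: int, pure_decode: bool) -> dict:
    """Hypothetical sketch: build one dummy batch per routine so the intended
    FA2 kernel path gets recorded in each cudagraph set."""
    # Pure decode: one query token per request, so max_query_len == 1 activates
    # FA2's decode routine (GQA/MQA optimization, possible flash-decode split_kv > 1).
    # Mixed prefill-decode: max_query_len > 1 activates the standard varlen kernel.
    max_query_len = 1 if pure_decode else 8  # 8 is an arbitrary illustrative value
    # seq_lens must be non-zero during capture; a zero tensor here was the bug
    # that prevented FA2 support in PR #16072 (no attention kernel launched).
    seq_lens = [max_query_len] * batch_size
    return {"max_query_len": max_query_len, "seq_lens": seq_lens}
```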

Details for supporting FlashInfer:

  • Use the persistent buffer trick.
  • Create one decode_wrapper per cudagraph batch size, as required by the FlashInfer API (see the sketch below).
  • Run pure decode batches with full cudagraphs, and fall back to piecewise cudagraphs for mixed prefill-decode batches.
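A hedged sketch of the per-batch-size decode wrappers, assuming FlashInfer's BatchDecodeWithPagedKVCacheWrapper with use_cuda_graph=True and pre-allocated index buffers; buffer sizes, capture sizes, and the KV layout below are illustrative:

```python
import torch
import flashinfer

# With use_cuda_graph=True the index buffers are fixed-size and baked into the
# captured graph, so each captured batch size needs its own wrapper.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
max_num_pages = 32768
cudagraph_capture_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]  # example list

decode_wrappers = {}
for bs in cudagraph_capture_sizes:
    decode_wrappers[bs] = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace,
        "NHD",
        use_cuda_graph=True,
        paged_kv_indptr_buffer=torch.zeros(bs + 1, dtype=torch.int32, device="cuda"),
        paged_kv_indices_buffer=torch.zeros(max_num_pages, dtype=torch.int32, device="cuda"),
        paged_kv_last_page_len_buffer=torch.zeros(bs, dtype=torch.int32, device="cuda"),
    )
```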

Launching command examples:

For FA2:

VLLM_FLASH_ATTN_VERSION=2 python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --compilation-config '{"full_cuda_graph":true, "separate_attention_routine":true}'

For FlashInfer:

VLLM_ATTENTION_BACKEND=FLASHINFER python -m ... --compilation-config '{"full_cuda_graph":true,"separate_attention_routine":true}'

Others:
FlashMLA: the compilation-config is '{"full_cuda_graph":true,"separate_attention_routine":true}'
FA3: set the env var VLLM_FLASH_ATTN_VERSION=3 and the compilation-config is '{"full_cuda_graph":true}'

Test Plan

Benchmark serving and lm_eval performance of FA2 and FlashInfer.

I have no plan to test FlashMLA and FA3, as I have no Hopper GPU at hand, but they should be fine since the current design is compatible with them. However, it would be very nice if somebody could help test them.

Test Result

Summary of results

Output token throughput is improved by 5% for FA2 and 2% for FlashInfer on Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4. TPOT is reduced by 2.9% and 3.1%, respectively. The lm_eval results are unchanged for both.

Details

machine: A100 40G, torch2.6 cuda12.4

Benchmark serving command:

python benchmarks/benchmark_serving.py --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100 --request-rate 20

FA2 benchmark serving:

piecewise cudagraph before this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.41
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 8.77
Output token throughput (tok/s): 1898.67
Total Token throughput (tok/s): 3937.88
---------------Time to First Token----------------
Mean TTFT (ms): 76.37
Median TTFT (ms): 71.08
P99 TTFT (ms): 191.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 17.08
Median TPOT (ms): 15.22
P99 TPOT (ms): 67.68
---------------Inter-token Latency----------------
Mean ITL (ms): 13.45
Median ITL (ms): 11.05
P99 ITL (ms): 72.61
==================================================

full cudagraph + piecewise fx graph in this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 10.87
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 9.20
Output token throughput (tok/s): 1992.27
Total Token throughput (tok/s): 4132.01
---------------Time to First Token----------------
Mean TTFT (ms): 78.69
Median TTFT (ms): 75.10
P99 TTFT (ms): 195.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.57
Median TPOT (ms): 14.78
P99 TPOT (ms): 78.21
---------------Inter-token Latency----------------
Mean ITL (ms): 12.83
Median ITL (ms): 10.34
P99 ITL (ms): 72.37
==================================================

FA2 lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8074 | ± 0.0109 |
| | | strict-match | 5 | exact_match | 0.7619 | ± 0.0117 |

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8074 | ± 0.0109 |
| | | strict-match | 5 | exact_match | 0.7619 | ± 0.0117 |

FlashInfer benchmark serving

piecewise cudagraph before this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.36
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.81
Output token throughput (tok/s): 1907.38
Total Token throughput (tok/s): 3955.65
---------------Time to First Token----------------
Mean TTFT (ms): 73.61
Median TTFT (ms): 69.59
P99 TTFT (ms): 184.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.85
Median TPOT (ms): 15.13
P99 TPOT (ms): 65.75
---------------Inter-token Latency----------------
Mean ITL (ms): 13.34
Median ITL (ms): 11.09
P99 ITL (ms): 71.82
==================================================

full cudagraph + piecewise fx graph after this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.13
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.99
Output token throughput (tok/s): 1946.35
Total Token throughput (tok/s): 4036.48
---------------Time to First Token----------------
Mean TTFT (ms): 76.03
Median TTFT (ms): 67.04
P99 TTFT (ms): 192.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.34
Median TPOT (ms): 14.96
P99 TPOT (ms): 58.86
---------------Inter-token Latency----------------
Mean ITL (ms): 13.11
Median ITL (ms): 10.71
P99 ITL (ms): 71.69
==================================================

FlashInfer lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8105 | ± 0.0108 |
| | | strict-match | 5 | exact_match | 0.7635 | ± 0.0117 |

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|-------|---------|--------|--------|--------|-------|--------|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8105 | ± 0.0108 |
| | | strict-match | 5 | exact_match | 0.7635 | ± 0.0117 |

One more thing: after merging some code from the main branch recently, I ran into a potential deadlock when testing this PR. It appears to be caused by earlier merged code, and PR #19927 seems to solve the problem.

fhl2000 added 2 commits June 25, 2025 13:36
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot left a comment

Summary of Changes

Hello @fhl2000, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a refined approach to full CUDA graph integration within vLLM, moving from a flattened FX graph to a wrapper-based strategy that preserves the piecewise graph structure. This new method facilitates broader full CUDA graph support for attention backends like FlashAttention 2 and FlashInfer, leading to measurable performance gains. Additionally, it includes a fix for a specific kernel compilation issue, enhancing overall system stability and compatibility.

Highlights

  • Enhanced Full CUDA Graph Implementation: Introduces a new strategy for full CUDA graph capture that wraps the piecewise compiled FX graph, rather than flattening it. This aims to reduce CPU overhead for non-captured batch sizes and offers greater flexibility, allowing dispatch to different CUDA graph sets for prefill-decode and pure decode stages.
  • FA2 and FlashInfer Support: Extends full CUDA graph support to FlashAttention 2 (FA2) and FlashInfer backends. This includes specific adaptations for their distinct prefill-decode and pure decode routines, enabling performance benefits for these attention backends.
  • Performance Improvements: Benchmarking results indicate a 5% improvement in output token throughput for FA2 and a 2% improvement for FlashInfer, with corresponding reductions in Time Per Output Token (TPOT) by 2.9% and 3.1% respectively.
  • Marlin Kernel Compilation Bug Fix: Addresses a minor bug where Marlin kernels were compiled for an unsupported GPU architecture (e.g., 8.7, while the RTX 4090 is compute capability 8.9), resolving 'RuntimeError: CUDA error: no kernel image is available for execution on the device' errors.
  • Separate Attention Routine Configuration: Adds a new separate_attention_routine flag to CompilationConfig, allowing for distinct CUDA graph capturing for prefill-decode and pure decode stages within attention backends that implement different branches for these cases.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new implementation for full cuda graph, adds support for FA2 and FlashInfer, and fixes a bug for Marlin kernels on Ada architecture. The core idea is to keep the piecewise graph structure and wrap it to capture the full CUDA graph, which avoids the CPU overhead of a large flattened graph. The changes are well-motivated, and the performance improvements are clearly demonstrated.

Signed-off-by: fhl <2410591650@qq.com>
@fhl2000 fhl2000 force-pushed the full_cudagraph_FA2_FlashInfer branch from bcf7cb9 to c2c5fea Compare June 25, 2025 08:33

mergify bot commented Jun 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 25, 2025
fhl2000 and others added 2 commits June 25, 2025 16:52
@mergify mergify bot removed the needs-rebase label Jun 25, 2025
fhl2000 added 2 commits June 25, 2025 10:03
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000 (Contributor Author) commented Jun 25, 2025

I have incorporated some checks for the new flag separate_attention_routine, so it is safe to launch now. This PR is now ready for review!

@fhl2000 fhl2000 marked this pull request as ready for review June 25, 2025 14:53
@fhl2000 fhl2000 changed the title [Core][Bugfix] new way for full cudagraph, add support for FA2 and FlashInfer; a minor bug fixed [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer; A minor bug fixed Jun 25, 2025
@fhl2000 (Contributor Author) commented Jun 26, 2025

Here is the workflow. At torch.compile initialization, the vllm backend wraps the split_gm into a full cudagraph wrapper class if compilation_config.full_cuda_graph is on. This wrapper class then takes responsibility for dispatching to the cudagraph entries of the separate attention routines. At runtime, dispatching is based on two key flags in the global forward_context: skip_attention_cuda_graphs and is_pure_decoding. When skip_attention_cuda_graphs is true, which implies using full cudagraphs, the wrapper class takes care of it: if separate_attention_routine is on, the wrapper further dispatches to either the decode-only full cudagraph or the mixed prefill-decode full cudagraph, according to the is_pure_decoding flag. On the other hand, if skip_attention_cuda_graphs is false, the wrapper class immediately falls back to the piecewise fx graph (the original split_gm), which relies on the CUDAPiecewiseBackend class to handle the piecewise cudagraph logic.
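A rough sketch of that runtime decision (a hypothetical helper; only the two forward-context flags and the class names mentioned above come from the PR):

```python
def choose_cudagraph_path(forward_context, separate_attention_routine: bool) -> str:
    """Sketch of the dispatch described above; return values are illustrative."""
    if not forward_context.skip_attention_cuda_graphs:
        # Piecewise path: handled by CUDAPiecewiseBackend inside split_gm.
        return "piecewise"
    if separate_attention_routine and forward_context.is_pure_decoding:
        return "full_decode_only"   # decode-only full cudagraph set
    return "full_mixed"             # mixed prefill-decode full cudagraph set
```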

@fhl2000 (Contributor Author) commented Jun 26, 2025


Please let me know if there are any questions or suggestions. I am currently planning to add some unit tests.

Signed-off-by: fhl <2410591650@qq.com>
@ProExpertProg (Collaborator) left a comment

I think this is a good approach overall!
My initial feedback:

  • I think we should try to consolidate CUDAGraph logic into a single class.
  • CUDAGraph logic is complex on main already, and this PR increases complexity significantly. We should add significantly more documentation. I also think we should consolidate various config flags and states.
  • There are benefits to compilation without splitting the graph (e.g. attention+quant fusion). We should add a new flag that maintains that ability (and assert the attention backend supports full cudagraph only). CUDAGraph logic can stay in the wrapper class.
  • This is a large PR, so it might help to split it. e.g. FlashInfer cg support can be added in a follow-up. But I'll let others chime in here.

Okay, this is plenty for now :D - thanks for the PR!

vllm/config.py Outdated
@@ -3984,6 +3984,14 @@ class CompilationConfig:
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
separate_attention_routine: bool = False
Collaborator

I think this should be named better. Perhaps split_attn_cudagraph? I also don't understand why this has to be a flag and we can't just ask the attention backend what it wants?

Contributor Author

I think we must leave such a flag in the global config, which tells the compiler backend to do the right thing. Otherwise, how is the attention backend supposed to communicate its requirements to the compiler? At least for now, the force_separate_routine flag of an attention backend has the ability to enforce its preference during the initialize_attn_backend phase of the gpu model runner.

Contributor Author

I think this should be named better. Perhaps split_attn_cudagraph?

I am not sure what name would be better. Btw, I'm afraid split_attn_cudagraph is not a good name. It sounds like splitting the full graph into a piecewise graph, where attention ops are the splitting ops, like what we have already done.

Collaborator

Good call on the name. Also makes sense we use this to communicate from attention backend to compiler. Let's make sure that happens inside set_splitting_ops_for_v1/somewhere inside config initialization, if we can.

Collaborator

I think we should figure out a different name for this; the current name doesn't indicate any relation to cudagraphs

Collaborator

I'm not as zoned into this PR as you folks are, but I have no clue what this flag is from the name.

Contributor Author

I think we should figure out a different name for this; the current name doesnt indicate any relation to cudagraphs

How about cudagraph_separate_routine? Dropping "attention" doesn't change the meaning, and while it is basically meant for the distinct attention routines that are actually executed, in the future that may cover more than just attention ops.

Contributor Author

Just changed to cudagraph_separate_routine. It should be better.

fhl2000 added 5 commits July 13, 2025 09:35
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@ProExpertProg (Collaborator) left a comment

Really nice work!!

Comment on lines +2365 to +2369
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
Collaborator

Really nice explanation and the logic is clear, thanks!


mergify bot commented Jul 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 14, 2025
@ProExpertProg (Collaborator)

Additionally, it would be great to get some tests in, especially for the capture/dispatch logic in CUDAGraphDispatcher and the config initialization. Perhaps we could mock the model variable and check that the forward context is set correctly? And we should test config initialization for all valid input configurations (and check that flags/modes/etc. are adjusted correctly).

@mergify mergify bot removed the needs-rebase label Jul 14, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000 fhl2000 changed the title [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer [Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer Jul 14, 2025
@LucasWilkinson (Collaborator) left a comment

Thank you for the contribution! I do like the idea of making support for piecewise cudagraphs and full cudagraphs in parallel a first-class citizen a lot; I'm generally a fan of this direction. Did a first pass; will finish up soon, but please do a scrub for typos, and I do think we can dramatically simplify the dispatching logic. Right now there are a lot of different flags, which makes it a bit confusing. I think we should basically make it so that we compile full cudagraphs when we can (if enabled), and if a compiled full cudagraph exists we use it; if not, we fall back on piecewise cudagraphs, i.e. something like:

dispatch_key = DispatchKey(
    num_reqs=...,
    num_tokens=...,
    uniform_batch=...,
)

# 1) Prefer a full CUDA graph if one exists.
if dispatch_key in self.full_cudagraphs:
    return self.full_cudagraphs[dispatch_key]

# 2) Otherwise, fall back to piecewise or direct execution.
return self.piecewise_cudagraph or self.model

the uniform_batch flag here would indicate that all of the requests in the batch have the same number of tokens; so num_reqs == num_tokens with uniform_batch would be pure decode.

Not using an is_pure_decode flag here would leave a door open for spec-decode support in the future, i.e. where "decode" steps are validating 2-4ish tokens at a time. So if we have a speculator set up to speculate 3 tokens at a time, we could create full cudagraphs for 3*num_reqs == num_tokens and uniform_batch. Something like FlashMLA would actually support this, since the main thing it wants is a uniform batch.
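For concreteness, a tiny illustrative sketch of how those conditions map onto batch kinds; spec_tokens_per_req is a hypothetical parameter matching the 3-token speculator example above:

```python
def classify_batch(num_reqs: int, num_tokens: int, uniform_batch: bool,
                   spec_tokens_per_req: int = 3) -> str:
    """Illustrative mapping from the proposed dispatch-key fields to batch kinds."""
    if uniform_batch and num_tokens == num_reqs:
        return "pure_decode"                 # one query token per request
    if uniform_batch and num_tokens == num_reqs * spec_tokens_per_req:
        return "uniform_spec_decode"         # e.g. the 3-token speculator above
    return "mixed_prefill_decode"
```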

cc @ProExpertProg

Collaborator

nit: maybe we can rename this file in a future PR (that just does file renaming) so we can see the diff here?

Contributor Author

Done!

cudagraph_runtime_style: CUDAGraphRuntimeStyle
# Be aware that is_pure_decode should be default None
# for both piecewise cudagraphs and no cudagraphs.
is_pure_decode: Optional[bool] = None
Collaborator

nit: I think we should name this use_full_cudagraph. I could see cases in the future where we might want to mix piecewise + full cudagraphs but with a different heuristic than is_pure_decode, e.g. for FA3 we may want to run small mixed decodes or spec-decode using full cudagraphs but run large prefills using piecewise cudagraphs

Collaborator

is_pure_decode is not equivalent to full cudagraphs. It is the same for backends that only support pure decode batches in CG, but not otherwise. While I agree the name could be improved, I don't think full_cudagraph is better.

Collaborator

ya no, I agree; wrote this early in the review, sorry! I'm actually more of a fan of uniform_batch now haha, see: #20059 (review)

Contributor Author

Just updated to uniform_batch (along with some related logic), though I haven't yet tested if it's okay with speculative decoding.

return CUDAGraphRuntimeStyle.NONE

def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle,
is_pure_decode: bool) -> Any:
Collaborator

ditto above; I don't think we should over-index on the is_pure_decode naming, since we may want to dispatch between full cudagraphs and piecewise on other metrics in the future

# the separate_attention_routine flag, but should inform
# the user that this flag can be turned on to obtain
# better performance.
if attn_cg == AttentionCGSupport.ALWAYS_SEPARATE and \
Collaborator

Based on this description I think the warning is very confusing to the user. They won't know what a "separate attention routine" means. We should try to think of new naming here; I would prefer something like:

{attn_backend_i.__name__} generally performs better when using
full cuda-graphs for only pure decode batches (and using piecewise
cuda-graphs for prefill and mixed-prefill-decode iterations) to enable
this behavior turn on 
CompilationConfig.mixed_full_cudagraphs_and_piecewise_cudagraphs

Some names to consider instead of separate_attention_routine to communicate that this flag is cudagraph related:
  • mixed_full_cudagraphs_and_piecewise_cudagraphs — a little long, but leaves the door open to mixing full cudagraphs with piecewise cudagraphs if we want to dispatch on something other than is_pure_decode (see comment above)
  • full_cudagraphs_for_pure_decode_only — a little shorter and aligns better with current behavior, but leaves little room to expand on this behavior in the future

# for full cudagraph, select between mixed batches
# or pure decode batches
decode_case = self.compilation_config.separate_attention_routine\
and is_pure_decode
Collaborator

I think we should get rid of is_pure_decode here. I could see cases in the future where we might want to mix piecewise + full cudagraphs but with a different heuristic than is_pure_decode, e.g. for FA3 we may want to run small mixed decodes or spec-decode using full cudagraphs but run large prefills using piecewise cudagraphs.

Honestly I find all these flags very confusing; I think a much simpler, more extensible dispatch logic would be:

dispatch_key = DispatchKey(num_reqs=..., num_tokens=..., uniform_batch=...)
if dispatch_key in self.full_cudagraphs:
     return self.full_cudagraphs[dispatch_key]
# Fall-back if a full_cudagraph isn't available 
return self.piecewise_cudagraph or self.model

the uniform_batch flag here would indicate that all of the requests in the batch have the same number of tokens; so num_reqs == num_tokens with uniform_batch would be pure decode, but this would leave a door open for spec-decode support in the future, i.e. where "decode" steps are validating 2-4ish tokens at a time. So if the speculator is set to speculate 3 tokens at a time, we could create full cudagraphs for 3*num_reqs == num_tokens and uniform_batch. Something like FlashMLA would actually support this, since the main thing it wants is a uniform batch.

Contributor Author

Post from previous discussions.

You are right about more explicit control of full cudagraphs, like all, mixed, and decode_only. But I think an extra flag may not be necessary now: after the cudagraph_mode proposed in #20283, the extra control can be achieved by extending the enum mode.

Like using NONE=0, PIECEWISE=1, FULL=2, FULL_DECODE_ONLY=3, and FULL_AND_PIECEWISE=4. Here, NONE means no cudagraph. PIECEWISE uses only piecewise cudagraphs (now the v1 default). FULL means the current strategy of maximum full cudagraph support (with separate_attention_routine tunable to achieve mixed-only or all). FULL_DECODE_ONLY uses only one set of cudagraphs for pure decode, and no cudagraphs for the rest. FULL_AND_PIECEWISE means explicitly having two sets of cudagraphs, with full cudagraphs for decode-only and piecewise cudagraphs for mixed batches or anything else. In this way, separate_attention_routine is forced to true in FULL_DECODE_ONLY and FULL_AND_PIECEWISE, and cascade attention can also be supported in these two modes.

I think @ProExpertProg and I agree to leave this explicit control of mixing full cudagraphs and piecewise cudagraphs to a follow-up PR that introduces a new cudagraph_mode like FULL_AND_PIECEWISE. Currently, the FULL mode just maximizes full cudagraph support with proper fallbacks, and separate_attention_routine takes effect only in this mode, to tell whether we want to retain a unified routine or use separate routines for different situations.

But I have to admit that separate_attention_routine seems a bit redundant now, as it is overridden when the attention backend's cudagraph support is PURE_DECODE_ONLY or ALWAYS_UNIFIED, and not overridden only when it is ALWAYS_SEPARATE.

Contributor Author

# Fall-back if a full_cudagraph isn't available 
return self.piecewise_cudagraph or self.model

I don't think we can fully separate the piecewise cudagraph and the raw model easily right now, since they are integrated together by torch.compile with vLLM's piecewise compilation.

the uniform_batch flag here would indicate that all of the requests in the batch have the same number of tokens; so if num_reqs == num_tokens and uniform_batch would be pure decode but this would leave a door open for spec-decode support the future i.e. where "decode" steps are validating 2-4ish tokens at the same time. So if we are speculator is set to speculate 3 tokens at a time we could create full-cudagraphs for 3*num_reqs == num_tokens and uniform_batch. Something like FlashMLA would actually support this since the main thing it wants is a uniform batch.

I do agree we should leave a door open for spec-decode, but I also think using the uniform_batch flag and num_tokens together for dispatching is somewhat confusing. First, if the speculator is set to speculate 3 tokens at a time, I guess this pattern is fixed, and we should just design a new enum representing that we are doing speculative decoding, rather than checking whether 3*num_reqs == num_tokens and uniform_batch. Also, given the meaning of uniform_batch you mentioned, it is not equivalent to pure decode or spec-decode; one counterexample is a pure prefill case where each request has 3 tokens.

Moreover, could we leave num_tokens to be handled by the cudagraph wrapper itself? Dispatching on num_tokens explicitly inside the dispatcher may be reasonable, but it is not trivial to manage all fused cases, and we couldn't reuse the uniformly designed cudagraph wrapper here. Leaving the exact cudagraph management of each edge case to the wrapper would be good, as one cudagraph wrapper could then represent exactly one case we want to handle.

Collaborator

+1 to leaving the token dispatching inside the wrapper.

@LucasWilkinson (Collaborator) commented Jul 16, 2025

I don't think we can now fully separate the piecewise_cudagraph and the raw model easily, since they are integrated together by the torch.compile with vllm piecewise compilation currently.

this was just there as pseudocode to represent the logic; like you said, the actual code would just be the model, since the piecewise dispatching is under that (and dispatches to eager if the batch is too large)

One counterexample may be in a pure prefill case where each request has 3 tokens.

that would be ok; there would be no difference between that and a spec decode from the attention perspective, so we would want it to use the full cudagraph if available in this case. We shouldn't fixate on the prefill/decode naming since chunked prefill and spec-decode blur these lines quite a bit; those names are just colloquially used to mean large query length or near-1 query length respectively. They are useful in conversation, but I'm a bit hesitant to harden those definitions (i.e. decode being query_len == 1) into the code, especially in this fairly core part of the code.

Moreover, could we leave num_tokens being handled by the cudagraph wrapper itself?
+1 to leaving the token dispatching inside the wrapper.

I like the idea of having it as part of the dispatch keys because it leaves the door open to having full cudagraphs for only small batch sizes; I could see a world where we might only compile full cudagraphs up to BS 128 and then use piecewise for everything larger.

Collaborator

But we already kind of support capturing full cudagraphs only up to a size smaller than max_num_seqs via the configurable compilation_config.cudagraph_capture_sizes. If that is the case, then here we would dispatch to the full-cudagraph wrapper only for it to fall back to eager inside the wrapper. I think this is very confusing for readers/users. I think we should use this opportunity to flatten and simplify the dispatching; as someone who is moderately familiar with cudagraphs in vLLM, I find all this dispatching very confusing.

This is where having a registry of the captured cudagraphs, with keys representing the workloads they support, could be far less confusing: we try to dispatch to a full cudagraph first; if one doesn't exist, dispatch to a piecewise cudagraph; and if that doesn't exist, dispatch to eager (I am aware the last 2 steps currently happen in the piecewise backend).

While I do see how later we might want to do more complex dispatching and size will certainly have to be involved, I think that's out of scope for this PR.

ya I don't think we have to go that far in this PR; I just really want to make sure we are creating the right extensible abstractions, since this PR introduces a lot of code that would have to be refactored to enable this.

Collaborator

I just really want to make sure we are creating the right extensible abstractions since this PR

I agree with that & thanks for driving this point. What about the following "compromise"/alternate solution: the runtime style (or mode, if we consolidate) and the dispatch key dispatch separately. The runtime style is used to select between the CUDAGraphWrapper instances (full and piecewise), and the DispatchKey dispatches between recorded cudagraphs within a wrapper, for now only including num_tokens and uniform_batch.

Structure would look like this:

  • CUDAGraphDispatcher decides the "runtime style" and the DispatchKey (whatever new name it receives), and sets it in the forward context
  • Each CUDAGraphWrapper inspects the forward context and only captures/replays if runtime style matches. It uses the DispatchKey to decide which cudagraph to capture/replay.

This solves the issue where we can't do cudagraph capture/replay directly in the dispatcher for piecewise cudagraphs. While piecewise cudagraphs might not need as much detailed dispatching, this would give us the flexibility to dispatch to different routines for ops that were not in splitting ops.

@fhl2000 no need to implement this yet, let's reach a consensus on this first

Collaborator

With this approach, the CUDAGraphWrapper instances could have even less logic and blindly trust the forward context on which cudagraph to dispatch to; if it doesn't exist yet, it gets captured. That way the dispatcher is the single source of truth on available cudagraphs.

In case I'm missing something, a "single source of truth" for "available" cudagraphs could be a new noinit dictionary on CompilationConfig, used by both the dispatcher and the CUDAGraphWrapper instances.

Contributor Author

Structure would look like this:

  • CUDAGraphDispatcher decides the "runtime style" and the DispatchKey (whatever new name it receives), and sets it in the forward context
  • Each CUDAGraphWrapper inspects the forward context and only captures/replays if runtime style matches. It uses the DispatchKey to decide which cudagraph to capture/replay.

Looks good to me.

With this approach, the CUDAGraphWrapper instances could have even less logic and blindly trust the forward context on what cg to dispatch to; if it doesn't exist yet, it gets captured. That way the dispatcher is the single source of truth on available cudagraphs.

If I'm missing something, a "single source of truth" for "available" cudagraphs a new noinit dictionary on CompilationConfig, used by both the dispatcher and the CUDAGraphWrapper instances.

Sharing a new noinit dictionary on CompilationConfig between the dispatcher and the CUDAGraphWrapper instances seems unviable to me (or would need improvement). While sharing a DispatchKey for the FULL style is possible, since one graph item is enough for one dispatch key, it is not good for piecewise cudagraphs, where many graph items (roughly one per layer) correspond to one DispatchKey.

Collaborator

Yeah, sorry, phrased poorly/dropped a sentence. I was saying we could use that dictionary if we need a single source of truth, but hopefully we don't. I don't think the dispatch key cares about any layer-wise info, so the different piecewise backends (one per subgraph) don't need to worry about it.


mergify bot commented Jul 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 15, 2025
Comment on lines +4012 to +4024
class CUDAGraphMode(enum.Enum):
# constants for the config of the cudagraph mode.
NONE = 0
PIECEWISE = 1
FULL = 2


class CUDAGraphRuntimeStyle(enum.Enum):
# constants for concrete cudagraph runtime style, used for
# runtime dispatching.
NONE = 0
PIECEWISE = 1
FULL = 2
Collaborator

Why do we have both?

@zou3519 (Collaborator) commented Jul 15, 2025

I am pretty confused about the CUDAGraphMode vs CUDAGraphRuntimeStyle. Is the reason that CUDAGraphMode=Full for FAv2 means that we do FULL for prefill but PIECEWISE for decode?

If so, would anyone want to do piecewise for prefill and piecewise for decode for FAv2?

Collaborator

+1 why do we have 2 identical enums

@fhl2000 (Contributor Author) commented Jul 16, 2025

Why do we have both?
+1 why do we have 2 identical enums

Sorry for causing confusion. But as explained in the code comments and in the earlier discussion, CUDAGraphMode is intended to configure the behavior of modes. This mode configuration would be extended in a future PR to contain more, for example FULL_DECODE_ONLY, FULL_AND_PIECEWISE, or something like AUTO as mentioned in #20283 that automatically selects among the previous modes. See below for the meaning of these modes.

Like use NONE=0, PIECEWISE=1, FULL=2, FULL_DECODE_ONLY=3, and FULL_AND_PIECEWISE=4. Here, NONE is for no cudagraph. PIECEWISE use only piecewise cudagraph (now v1 default). FULL means the current strategy for maximum full cudagraph support (with separate_attention_routine tunable to achieve mixed only or all). FULL_DECODE_ONLY uses only one set of cudagraph for pure decode, and no cudagraph for the rest. FULL_AND_PIECEWISE means explicitly having two sets of cudagraph, with full cudagraph for decode-only, and piecewise cudagraph for mixed batches or any rest.

On the other side, CUDAGraphRuntimeStyle is the actual style of cudagraphs we select to run at runtime. I think there are only three styles to be shared among all possibilities, and it can't be extended. It is also used as a property assigned to the CUDAGraph wrapper class for correctly activating cudagraphs of the right style, because we can currently have nested CUDAGraph wrappers, i.e., a piecewise cudagraph wrapper integrated with the piecewise compiled model inside, and one wrapper wrapped outside for full cudagraph.

They happen to have identically named members now, but this should look fine after the cudagraph mode is extended.

Collaborator

I still think mode can be enough and we just add asserts that during runtime we're not using any of the "mixed" modes that aren't valid as runtime styles.

Contributor Author

I am pretty confused about the CUDAGraphMode vs CUDAGraphRuntimeStyle. Is the reason that CUDAGraphMode=Full for FAv2 means that we do FULL for prefill but PIECEWISE for decode?

No. For FA2, when CUDAGraphMode is set to FULL, it means using full cudagraphs for both mixed prefill-decode stages and pure decode stages. However, since FA2's cudagraph support is marked as ALWAYS_SEPARATE, it prefers separate cudagraph routines for these two stages. Only when separate_attention_routine is set to False is there a single full cudagraph for mixed prefill-decode batches, which is also compatible with pure decode scenarios.

Contributor Author

Honestly, the current design of the FULL mode just intends to do as much as it can to support the full cudagraph runtime style, while falling back to the piecewise cudagraph or no-cudagraph runtime style for any incompatible routine.

Contributor Author

I still think mode can be enough and we just add asserts that during runtime we're not using any of the "mixed" modes that aren't valid as runtime styles.

But I think fusing the usage of cudagraph mode for both the semantics of "mode" and "runtimestyle" would lead to more confusion.

@ProExpertProg (Collaborator) commented Jul 16, 2025

But I think fusing the usage of cudagraph mode for both the semantics of "mode" and "runtimestyle" would lead to more confusion.

I think it's a tradeoff, and I don't think the semantics are different enough. It's also easier to introduce a separate enum later if needed. To distinguish between them, all variables currently of type RuntimeStyle can have runtime_mode in the name instead of just mode, if you want the semantic meaning to be clearer.

@LucasWilkinson @zou3519 what do you think?

@mergify mergify bot removed the needs-rebase label Jul 17, 2025
fhl2000 and others added 3 commits July 17, 2025 23:59
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>

mergify bot commented Jul 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 18, 2025
@ProExpertProg (Collaborator) commented Jul 18, 2025

Hey, sorry for the late response here. Lucas, Sage, and I discussed this at length yesterday and settled on an extension of the last proposal I made in a comment above. I then spent some more time thinking about the implementation and came up with the flowchart below representing the code structure, plus some more context around it. Please let me know if you have any questions, and if you need any help. Also feel free to reuse any part of this comment in code comments or the PR description. We can also adapt it and add it to the documentation.

Motivation

The reason for another adjustment is this proposal from Lucas:

Honestly I find all these flags very confusing; I think I much simpler more extensible dispatch logic would be:

dispatch_key = DispatchKey(num_reqs=..., num_tokens=..., uniform_batch=...)
if dispatch_key in self.full_cudagraphs:
     return self.full_cudagraphs[dispatch_key]
# Fall-back if a full_cudagraph isn't available 
return self.piecewise_cudagraph or self.model

I agree that dispatching between multiple CUDAGraphWrapper instances just for those instances to do more lookups and dispatching is not ideal.

Proposal

My proposal is to pull dispatching fully out of the CUDAGraphWrapper instances and make the CUDAGraphDispatcher fully responsible for it. This way the dispatcher is the single source of truth for available CUDAGraphs. It communicates with the wrappers via the forward context, and the wrappers' logic is further simplified, as they only need to handle passthrough, capture, and replay, without any fallback logic. The detailed structure is shown below. I tried to color-code the parts, but we can adjust this later and include it in the documentation. The model execution stack is shown in the middle, starting at the various GPUModelRunner entry points and passing through the wrappers, compilation, and other layers (only relevant ones shown).

[Flowchart: CUDAGraph-dispatch-final (drawio diagram)]

There are three main components that need to be adjusted: GPUModelRunner, CUDAGraphDispatcher, and CUDAGraphWrapper. The GPUModelRunner invokes the initialization in the dispatcher, initiates the profile_run and capture_run, and executes the model forward pass. CUDAGraphDispatcher acts as an oracle, having full information about available CUDAGraphs (inferred from attention backend support and config), and chooses which CUDAGraph to dispatch to (or run eager). CUDAGraphWrapper always performs one of three actions: nothing (passthrough), capture, or replay.

At the core of dispatching is the new CUDAGraphDispatchKey structure, which includes all of the info required to dispatch to a CUDAGraph. It contains num_tokens (already used for dispatching on main), as well as num_reqs and is_uniform (to help dispatch between different kinds of batches). The latter two are optional; for (attention) kernels that only care about num_tokens, they are set to None, e.g. when excluding attention from cuda graphs in piecewise mode. (Alternatively, we could use a Union[int, CUDAGraphDispatchKey]; I thought this is simpler, but feel free to deviate from it.)
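A minimal sketch of that structure, assuming it ends up as a frozen (hashable) dataclass; the field names follow the proposal above, nothing here is final:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)  # frozen makes it hashable, so it can live in dicts/sets
class CUDAGraphDispatchKey:
    """Everything needed to look up a captured CUDA graph (sketch)."""
    num_tokens: int                    # padded token count, already used on main
    num_reqs: Optional[int] = None     # None = "don't care" (e.g. piecewise mode)
    is_uniform: Optional[bool] = None  # None = "don't care"
```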

The dispatcher stores two sets of dispatch keys, each corresponding to one of the modes/cudagraph wrappers and representing the keys for which cudagraphs are available for each mode. These sets are initialized inside init depending on attention support and what mode is enabled, for example:

  • Attention backend supports full cg ALWAYS_UNIFIED, full cudagraphs enabled:
    full_cuda_graphs=[(i, None, None) for i in config.cudagraph_capture_sizes]
  • Attention backend supports full cg PURE_DECODE_ONLY, full cudagraphs enabled:
    full_cuda_graphs=[(i, i, True) for i in config.cudagraph_capture_sizes]
    piecewise_cuda_graphs=[(i, None, None) for i in config.cudagraph_capture_sizes]
  • Attention backend supports full cg ALWAYS_SEPARATE, full cudagraphs enabled:
    full_cuda_graphs=[(i, i, True) for i in config.cudagraph_capture_sizes] +
    [(i, None, None) for i in config.cudagraph_capture_sizes]
  • Piecewise cudagraphs enabled:
    piecewise_cuda_graphs=[(i, None, None) for i in config.cudagraph_capture_sizes]

Init is actually the place where we can get much fancier in the future and “prepare” all kinds of cudagraph combos. And the dispatch, capture and replay can all stay the same as they all just blindly follow what’s available.

When dispatching before model execution, the dispatcher looks in its sets for the dispatch key corresponding to the batch, and returns the key and mode corresponding to a valid combination, or (None, None) otherwise. It prioritizes full cudagraphs and specific keys over general/wildcard keys. One important thing here is that the returned key might be more general than the key passed in; the returned key is what should be put in the forward context, as the wrapper will just follow it blindly and not do any more generalization/dispatching.
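As a sketch of that lookup, assuming the CUDAGraphDispatchKey dataclass sketched above and using plain strings in place of the mode enum (priority as described: full before piecewise, specific keys before the num_tokens-only wildcard):

```python
from typing import Optional, Set, Tuple


def dispatch(key: CUDAGraphDispatchKey,
             full_cuda_graphs: Set[CUDAGraphDispatchKey],
             piecewise_cuda_graphs: Set[CUDAGraphDispatchKey],
             ) -> Tuple[Optional[CUDAGraphDispatchKey], Optional[str]]:
    """Return (key to place in the forward context, mode), or (None, None) for eager."""
    # Wildcard variant of the key: same num_tokens, batch shape ignored.
    wildcard = CUDAGraphDispatchKey(key.num_tokens, None, None)
    for candidate in (key, wildcard):        # specific key first, then wildcard
        if candidate in full_cuda_graphs:    # full cudagraphs take priority
            return candidate, "FULL"
    if wildcard in piecewise_cuda_graphs:
        return wildcard, "PIECEWISE"
    return None, None                        # nothing captured: run eager
```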

The wrapper's responsibilities are further reduced from what's currently in the PR. It now only needs to check whether it is activated (via the mode) and, if not, pass through; otherwise it replays an available cudagraph for a key, or captures it if it doesn't exist. This way there is no implicit contract between the dispatcher and the wrapper; instead the wrapper directly trusts what's in the forward context. The simple existence check also makes it easy to support lazy capture in the future.

One note about GPUModelRunner: I did not include the _dummy_run utility in the diagram. That was only done to simplify the diagram and because I wasn't sure what its exact signature should be. Let's just try to reduce the number of flags passed and hopefully limit the cudagraph- and attention-related ones to CUDAGraphDispatchKey and CUDAGraphMode - no problem if that's not possible.

I think it might also be possible to remove build_for_cudagraph_capture and can_run_in_cudagraph, but we can also do that later.

Required changes (from what’s currently on this PR)

  • Slight simplification of CUDAGraphWrapper
  • Addition of CUDAGraphDispatchKey (open to a better name, maybe BatchDescriptor?)
  • Revamp of CUDAGraphDispatcher: moving init into it and modifying the dispatch function (no longer needing access to the model)
  • Directly wrapping GPUModelRunner.model with CUDAGraphWrapper(mode=FULL)
  • Pruning as many of the current dispatch flags as possible (definitely from the forward context, hopefully also from _dummy_run)
  • I’m sure I missed some and there will be new implementation difficulties.
  • After we settle on the exact implementation, please add detailed docstrings to classes and methods, and add tests for the dispatching logic.

Lucas’s original proposal

Lucas proposed combining CUDAGraphWrapper(mode=FULL) and CUDAGraphDispatcher so that cudagraph dispatch, capture, and replay all happen in one place for full cudagraphs, while restoring piecewise to the current implementation on main (dispatching only based on num_tokens for piecewise). While dispatch and capture/replay for full cudagraphs would then all happen in one place, we still need to use the forward context to communicate with the piecewise wrapper to make sure it doesn't interfere with the warmup+capture of the full cudagraphs. The piecewise wrapper also has to duplicate the capture & replay logic, making the implementation harder to understand in my opinion. While passing the CUDAGraphDispatchKey in the forward context introduces some distance between dispatch and capture/replay, I believe it is the simplest and most capable overall implementation with the cleanest separation of concerns.

@yinghai (Contributor) commented Jul 19, 2025

My proposal is to pull dispatching fully out of CUDAGraphWrapper instances and make the CUDAGraphDispatcher fully responsible for it.

Cool, this makes a lot of sense to me. The dispatcher manages the graph managers, but the graph managers can be of different types (full, or piecewise with torch.compile).

@mergify mergify bot removed the needs-rebase label Jul 20, 2025
Labels: ci/build, rocm, v1